import abc
import logging
from contextlib import contextmanager, asynccontextmanager
from typing import Any, List, Dict, Optional, Type
import pydantic

from opik import exceptions


# Module-level logger; used below to report failures when calling LLM providers.
LOGGER = logging.getLogger(__name__)


class OpikBaseModel(abc.ABC):
    """
    This class serves as an interface to LLMs.

    If you want to implement a custom LLM provider in evaluation metrics,
    you should inherit from this class.

    Subclasses must implement `generate_string()` and
    `generate_provider_response()`. The async counterparts
    (`agenerate_string()` and `agenerate_provider_response()`) are optional:
    their default implementations raise `NotImplementedError`.
    """

    def __init__(self, model_name: str):
        """
        Initializes the base model with a given model name.

        Args:
            model_name: The name of the LLM to be used. Stored as `self.model_name`.
        """
        self.model_name = model_name

    @abc.abstractmethod
    def generate_string(
        self,
        input: str,
        response_format: Optional[Type[pydantic.BaseModel]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Simplified interface to generate a string output from the model.

        Args:
            input: The input string based on which the model will generate the output.
            response_format: Optional pydantic model class describing the expected
                structure of the output. Whether and how it is honored is up to the
                implementation. Defaults to None.
            kwargs: Additional arguments that may be used by the model for string generation.

        Returns:
            str: The generated string output.
        """
        pass

    @abc.abstractmethod
    def generate_provider_response(
        self, messages: List[Dict[str, Any]], **kwargs: Any
    ) -> Any:
        """
        Do not use this method directly. It is intended to be used within `get_provider_response()` method.

        Generate a provider-specific response. Can be used to interface with
        the underlying model provider (e.g., OpenAI, Anthropic) and get raw output.

        Args:
            messages: A list of messages to be sent to the model, should be a list of dictionaries with the keys
                "content" and "role".
            kwargs: arguments required by the provider to generate a response.

        Returns:
            Any: The response from the model provider, which can be of any type depending on the use case and LLM.
        """
        pass

    async def agenerate_string(
        self,
        input: str,
        response_format: Optional[Type[pydantic.BaseModel]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Simplified interface to generate a string output from the model. Async version.

        The default implementation is not provided; subclasses that support
        async generation must override it.

        Args:
            input: The input string based on which the model will generate the output.
            response_format: Optional pydantic model class describing the expected
                structure of the output. Whether and how it is honored is up to the
                implementation. Defaults to None.
            kwargs: Additional arguments that may be used by the model for string generation.

        Returns:
            str: The generated string output.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError("Async generation not implemented for this provider")

    async def agenerate_provider_response(
        self, messages: List[Dict[str, Any]], **kwargs: Any
    ) -> Any:
        """
        Do not use this method directly. It is intended to be used within `aget_provider_response()` method.

        Generate a provider-specific response. Can be used to interface with
        the underlying model provider (e.g., OpenAI, Anthropic) and get raw output.
        Async version.

        Args:
            messages: A list of messages to be sent to the model, should be a list of dictionaries with the keys
                "content" and "role".
            kwargs: arguments required by the provider to generate a response.

        Returns:
            Any: The response from the model provider, which can be of any type depending on the use case and LLM.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError("Async generation not implemented for this provider")


@contextmanager
def get_provider_response(
    model_provider: OpikBaseModel, messages: List[Dict[str, Any]], **kwargs: Any
) -> Any:
    """
    Provides a context manager for getting and managing the response from a
    model provider. Ensures that errors during the interaction with the model
    provider are handled appropriately and logged.

    Args:
        model_provider: Instance of a class derived from `OpikBaseModel`
            responsible for interfacing with the model.
        messages: List of dictionaries containing the messages or inputs to be
            passed to the model.
        **kwargs: Additional keyword arguments to customize the generation of
            the model responses.

    Yields:
        Any: The response generated by the model provider.

    Raises:
        exceptions.BaseLLMError: If the response generation from the model provider
            fails due to an exception. Because the ``yield`` is inside the
            ``try`` block, exceptions raised in the caller's ``with`` body are
            also logged and converted. An exception that is already a
            ``BaseLLMError`` is re-raised unchanged. Any other exception is
            wrapped and kept as ``__cause__``.
    """
    try:
        yield model_provider.generate_provider_response(messages, **kwargs)
    except Exception as e:
        LOGGER.error("Failed to call LLM provider, reason: %s", e)
        if isinstance(e, exceptions.BaseLLMError):
            # Already the documented error type: re-raise as-is so any subclass
            # type and the original traceback are preserved.
            raise
        # Chain explicitly so the underlying provider error stays visible.
        raise exceptions.BaseLLMError(str(e)) from e


@asynccontextmanager
async def aget_provider_response(
    model_provider: OpikBaseModel, messages: List[Dict[str, Any]], **kwargs: Any
) -> Any:
    """
    Asynchronous context manager for getting a response from a model provider.

    This function asynchronously interacts with the specified `model_provider` to
    generate a response based on the given list of `messages` and additional
    optional keyword arguments. If an error occurs during this process, it is
    logged, and a custom exception is raised.

    Args:
        model_provider: The model provider from which to request
            the response.
        messages: A list of dictionaries containing the
            messages for the model provider to process.
        **kwargs: Additional keyword arguments passed to the model provider's
            response generation method.

    Yields:
        Any: The response generated asynchronously by the model provider.

    Raises:
        exceptions.BaseLLMError: If there is an error during the asynchronous
            interaction with the model provider. Because the ``yield`` is inside
            the ``try`` block, exceptions raised in the caller's ``async with``
            body are also logged and converted. An exception that is already a
            ``BaseLLMError`` is re-raised unchanged. Any other exception is
            wrapped and kept as ``__cause__``.
    """
    try:
        response = await model_provider.agenerate_provider_response(
            messages=messages, **kwargs
        )
        yield response
    except Exception as e:
        LOGGER.error("Failed to call LLM provider asynchronously, reason: %s", e)
        if isinstance(e, exceptions.BaseLLMError):
            # Already the documented error type: re-raise as-is so any subclass
            # type and the original traceback are preserved.
            raise
        # Chain explicitly so the underlying provider error stays visible.
        raise exceptions.BaseLLMError(str(e)) from e


def check_model_output_string(output: Optional[str]) -> str:
    """
    Validate that an LLM produced an output, returning it unchanged.

    A ``None`` output usually points to a misconfigured environment (for
    example a missing provider API key), so it is turned into an error with a
    hint about what to check instead of being passed along silently.

    Args:
        output: The string returned by the language model, possibly None.

    Returns:
        The same ``output`` value, now known not to be None.

    Raises:
        exceptions.BaseLLMError: If ``output`` is None. The message suggests
        verifying the environment configuration and model API keys.
    """
    # Guard clause: the common, valid case returns immediately.
    if output is not None:
        return output

    hint = (
        "Received None as the output from the LLM. Please verify your environment configuration "
        "and ensure that the API keys for the models in use (e.g., OPENAI_API_KEY) are set correctly."
    )
    raise exceptions.BaseLLMError(hint)
